{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "M-mTmXz9USNe",
        "outputId": "d2a37614-9c84-4460-af18-938faa296e5b"
      },
      "outputs": [],
      "source": [
        "# Install/upgrade the packages needed for 4-bit loading (bitsandbytes, accelerate).\n",
        "# %pip (instead of !pip) installs into the environment of the running kernel.\n",
        "%pip install -q --upgrade bitsandbytes accelerate"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "FW8nl3XRFrz0"
      },
      "outputs": [],
      "source": [
        "# imports\n",
        "\n",
        "import os\n",
        "import requests\n",
        "from IPython.display import Markdown, display, update_display\n",
        "from openai import OpenAI\n",
        "from google.colab import drive\n",
        "from huggingface_hub import login\n",
        "from google.colab import userdata\n",
        "from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer, BitsAndBytesConfig\n",
        "import torch"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "xYW8kQYtF-3L"
      },
      "outputs": [],
      "source": [
        "# Checkpoint identifiers that the UI callback hands to `generate`.\n",
        "DEEPSEEK = \"deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B\"\n",
        "LLAMA = \"meta-llama/Llama-3.2-3B-Instruct\"\n",
        "\n",
        "# Read the Hugging Face token from Colab's user data store and log in with it.\n",
        "hf_token = userdata.get('HF_TOKEN')\n",
        "login(token=hf_token, add_to_git_credential=True)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "piEMmcSfMH-O"
      },
      "outputs": [],
      "source": [
        "# System prompt used for every request: asks for short question/answer flashcards\n",
        "# rendered as collapsible <details>/<summary> markdown.\n",
        "# The ###TEMPLATE### markers delimit the worked example; `call_generate` strips the\n",
        "# region between them from the decoded text before extracting the flashcards.\n",
        "system_message = \"\"\"\n",
        "You are an specialized tutor in creating flashcards about whatever topic the user decides to research.\n",
        "They need to be brief, with a short question and a short answer in the following markdown format example\n",
        "###TEMPLATE###\n",
        "# Flashcard 1\n",
        "<details>\n",
        "<summary>What is the capital of France?</summary>\n",
        "Paris\n",
        "</details>\n",
        "\n",
        "# Flashcard 2\n",
        "\n",
        "<details>\n",
        "<summary>What is the derivative of sin(x)?</summary>\n",
        "cos(x)\n",
        "</details>\n",
        "###TEMPLATE###\n",
        "\"\"\""
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "UcRKUgcxMew6"
      },
      "outputs": [],
      "source": [
        "# bitsandbytes settings passed to `from_pretrained` when `generate(..., quant=True)`.\n",
        "# NOTE(review): the notebook metadata targets a T4 GPU, which (as far as I know) has no\n",
        "# native bfloat16 support - confirm bfloat16 is the intended compute dtype.\n",
        "quant_config = BitsAndBytesConfig(\n",
        "    # storage format of the weights\n",
        "    load_in_4bit=True,\n",
        "    bnb_4bit_quant_type=\"nf4\",\n",
        "    bnb_4bit_use_double_quant=True,\n",
        "    # dtype used for the matmuls at inference time\n",
        "    bnb_4bit_compute_dtype=torch.bfloat16,\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": true,
        "id": "HdQnWEzW3lzP"
      },
      "outputs": [],
      "source": [
        "# Wrapping everything in a function - and adding Streaming and generation prompts\n",
        "\n",
        "def generate(model, messages, quant=True, stream = True, max_new_tokens=500):\n",
        "  \"\"\"Load a chat model, run one generation for `messages`, and return the decoded text.\n",
        "\n",
        "  The tokenizer and model are loaded on every call and released again before returning.\n",
        "\n",
        "  Args:\n",
        "    model: Checkpoint id passed to `from_pretrained` (e.g. the LLAMA or DEEPSEEK constants).\n",
        "    messages: Chat history as a list of {\"role\": ..., \"content\": ...} dicts.\n",
        "    quant: When True, load the weights with the notebook-level `quant_config`.\n",
        "    stream: When True, print tokens through a `TextStreamer` while generating.\n",
        "    max_new_tokens: Upper bound on the number of newly generated tokens.\n",
        "\n",
        "  Returns:\n",
        "    The decoded first output sequence with special tokens removed. The sequence\n",
        "    returned by `model.generate` is decoded as a whole, so the prompt text is included.\n",
        "  \"\"\"\n",
        "  import gc  # local import: keeps this cell independent of the imports cell\n",
        "\n",
        "  tokenizer = AutoTokenizer.from_pretrained(model)\n",
        "  tokenizer.pad_token = tokenizer.eos_token\n",
        "\n",
        "  # Use a separate name for the loaded network instead of rebinding the `model` argument.\n",
        "  if quant:\n",
        "    # Place the quantized weights at load time; calling .to() on a bitsandbytes model\n",
        "    # afterwards is rejected by some transformers/bitsandbytes version combinations.\n",
        "    llm = AutoModelForCausalLM.from_pretrained(model, quantization_config=quant_config, device_map=\"cuda\")\n",
        "  else:\n",
        "    llm = AutoModelForCausalLM.from_pretrained(model).to(\"cuda\")\n",
        "\n",
        "  try:\n",
        "    # return_dict=True returns input_ids together with the tokenizer's own attention_mask,\n",
        "    # so the mask no longer has to be fabricated with ones_like.\n",
        "    inputs = tokenizer.apply_chat_template(\n",
        "        messages,\n",
        "        add_generation_prompt=True,\n",
        "        return_tensors=\"pt\",\n",
        "        return_dict=True,\n",
        "    ).to(llm.device)\n",
        "\n",
        "    # Only build a streamer when streaming was requested.\n",
        "    streamer = TextStreamer(tokenizer) if stream else None\n",
        "    outputs = llm.generate(\n",
        "        input_ids=inputs[\"input_ids\"],\n",
        "        attention_mask=inputs[\"attention_mask\"],\n",
        "        max_new_tokens=max_new_tokens,\n",
        "        streamer=streamer,\n",
        "    )\n",
        "    return tokenizer.decode(outputs[0], skip_special_tokens=True)\n",
        "  finally:\n",
        "    # Release the model even when generation fails, so repeated calls\n",
        "    # (e.g. from the Gradio UI) do not pile up GPU memory.\n",
        "    del llm\n",
        "    gc.collect()\n",
        "    torch.cuda.empty_cache()\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 710,
          "referenced_widgets": [
            "c07d99864c17468091385a5449ad39db",
            "d1164091bab34a37a41a62ca66bd4635",
            "59a24e217f474d028436d95846c2fc17",
            "4776f1a85807460b9494377ce242887d",
            "82b8a20d2a8647faac84c46bd9e1248b",
            "991ebb206ead4e30818dc873fd5650ac",
            "e7d6ddd317c44472a9afeb63dee8d982",
            "28b2d565e7a0455eb362c02581604d3b",
            "2046de5490c8468da7c96f1528ab9a1c",
            "ba27365f3f124c359fa6e07c23af182c",
            "b139d8162b354551ad09c957cc842506"
          ]
        },
        "id": "jpM_jxeT4Bv3",
        "outputId": "75181c1d-8589-45ce-e5e0-d5974ada080c"
      },
      "outputs": [],
      "source": [
        "import gradio as gr\n",
        "import re\n",
        "\n",
        "# Dropdown label -> checkpoint id; single source of truth for the UI and the callback.\n",
        "MODEL_CHOICES = {\"LLAMA\": LLAMA, \"DEEPSEEK\": DEEPSEEK}\n",
        "\n",
        "# Generation budget passed to `generate` for one batch of flashcards.\n",
        "FLASHCARD_MAX_NEW_TOKENS = 2000\n",
        "\n",
        "# Matches the example block of the system prompt, including both full ###TEMPLATE### markers.\n",
        "_TEMPLATE_BLOCK = re.compile(r'###TEMPLATE#*.*?###TEMPLATE#*', flags=re.DOTALL)\n",
        "# Matches everything from the first flashcard heading to the last closing </details>.\n",
        "_FLASHCARDS = re.compile(r\"# Flashcard 1[\\s\\S]*</details>\")\n",
        "\n",
        "\n",
        "def _extract_flashcards(decoded):\n",
        "    \"\"\"Return only the flashcard markdown found in `decoded`, or `decoded` unchanged if none is found.\"\"\"\n",
        "    without_template = _TEMPLATE_BLOCK.sub('', decoded)\n",
        "    found = _FLASHCARDS.search(without_template)\n",
        "    return found.group(0) if found else decoded\n",
        "\n",
        "\n",
        "def call_generate(model_name, topic, num_flashcards):\n",
        "    \"\"\"Gradio callback for the \"Generate Flashcards\" button.\n",
        "\n",
        "    Args:\n",
        "        model_name: Label chosen in the model dropdown (a key of MODEL_CHOICES).\n",
        "        topic: Free-text topic typed by the user.\n",
        "        num_flashcards: Requested number of flashcards (slider value).\n",
        "\n",
        "    Returns:\n",
        "        Markdown for the output component: the extracted flashcards, the raw decoded\n",
        "        text when no flashcard block is found, or a short message for invalid input.\n",
        "    \"\"\"\n",
        "    checkpoint = MODEL_CHOICES.get(model_name)\n",
        "    if checkpoint is None:\n",
        "        return \"Invalid model selected.\"\n",
        "\n",
        "    # Do not spend a model load and generation on an empty request.\n",
        "    topic = (topic or \"\").strip()\n",
        "    if not topic:\n",
        "        return \"Please enter a topic.\"\n",
        "\n",
        "    # int() keeps the prompt at e.g. \"5 flashcards\" even if the slider delivers 5.0.\n",
        "    messages = [\n",
        "        {\"role\": \"system\", \"content\": system_message},\n",
        "        {\"role\": \"user\", \"content\": f\"I want to know more about {topic}. Please provide {int(num_flashcards)} flashcards.\"}\n",
        "    ]\n",
        "\n",
        "    decoded = generate(checkpoint, messages, stream=False, max_new_tokens=FLASHCARD_MAX_NEW_TOKENS)\n",
        "    return _extract_flashcards(decoded)\n",
        "\n",
        "with gr.Blocks() as ui:\n",
        "    with gr.Row():\n",
        "        model_dropdown = gr.Dropdown(choices=list(MODEL_CHOICES), value=\"LLAMA\", label=\"Model\")\n",
        "    with gr.Row():\n",
        "        topic_selector = gr.Textbox(label=\"Type the topic you want flashcards:\", max_lines=1, max_length=50)\n",
        "        num_flashcards = gr.Slider(\n",
        "                    minimum=1,\n",
        "                    maximum=10,\n",
        "                    step=1,\n",
        "                    value=5,\n",
        "                    label=\"Nr. Flashcards\",\n",
        "                )\n",
        "    with gr.Row():\n",
        "        generate_button = gr.Button(\"Generate Flashcards\")\n",
        "    with gr.Row():\n",
        "        output = gr.Markdown()\n",
        "\n",
        "    # Hooking up events to callbacks\n",
        "    generate_button.click(\n",
        "        call_generate,\n",
        "        inputs=[model_dropdown, topic_selector, num_flashcards],\n",
        "        outputs=output\n",
        "    )\n",
        "\n",
        "ui.launch(inbrowser=True, debug=True)"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "gpuType": "T4",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    },
    "widgets": {
      "application/vnd.jupyter.widget-state+json": {
        "2046de5490c8468da7c96f1528ab9a1c": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "ProgressStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "28b2d565e7a0455eb362c02581604d3b": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "4776f1a85807460b9494377ce242887d": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_ba27365f3f124c359fa6e07c23af182c",
            "placeholder": "​",
            "style": "IPY_MODEL_b139d8162b354551ad09c957cc842506",
            "value": " 2/2 [00:35&lt;00:00, 15.99s/it]"
          }
        },
        "59a24e217f474d028436d95846c2fc17": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "FloatProgressModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_28b2d565e7a0455eb362c02581604d3b",
            "max": 2,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_2046de5490c8468da7c96f1528ab9a1c",
            "value": 2
          }
        },
        "82b8a20d2a8647faac84c46bd9e1248b": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "991ebb206ead4e30818dc873fd5650ac": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "b139d8162b354551ad09c957cc842506": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "ba27365f3f124c359fa6e07c23af182c": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "c07d99864c17468091385a5449ad39db": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HBoxModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_d1164091bab34a37a41a62ca66bd4635",
              "IPY_MODEL_59a24e217f474d028436d95846c2fc17",
              "IPY_MODEL_4776f1a85807460b9494377ce242887d"
            ],
            "layout": "IPY_MODEL_82b8a20d2a8647faac84c46bd9e1248b"
          }
        },
        "d1164091bab34a37a41a62ca66bd4635": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_991ebb206ead4e30818dc873fd5650ac",
            "placeholder": "​",
            "style": "IPY_MODEL_e7d6ddd317c44472a9afeb63dee8d982",
            "value": "Loading checkpoint shards: 100%"
          }
        },
        "e7d6ddd317c44472a9afeb63dee8d982": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        }
      }
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
